% FACE RECOGNITION USING D-KSVD
close all
clearvars
clc

%% LOAD IMAGES
% The set holds 10 persons with 20 images per person.
% The first 10 images of each person are used for training and the
% remaining 10 for testing (the split itself is done further below).

addpath('utilities')
addpath('data')

% Root folder of the face data set; edit this single line to relocate it.
dataDir = 'Y:/Projects/MATLAB Projects/Sparse Dictionary Learning/data/face recognition/malestaff';

nPersons = 10;
nImagesPerPerson = 20;

% images(:, :, imageID, personID) holds one grayscale image as double.
% The stack is allocated as soon as the size of the first image is known.
images = [];

for personID = 1:nPersons
    for imageID = 1:nImagesPerPerson
        fileName = sprintf('%s/%d/%02d.jpg', dataDir, personID, imageID);

        % "img" rather than "image": the latter would shadow the built-in
        % image() function for the rest of the script.
        img = imread(fileName);

        % rgb2gray expects an m-by-n-by-3 truecolor image, so only convert
        % files that actually carry three channels.
        if size(img, 3) == 3
            img = rgb2gray(img);
        end
        img = im2double(img);

        if isempty(images)
            % Allocate the whole stack once instead of growing it on every
            % iteration of the double loop.
            images = zeros([size(img), nImagesPerPerson, nPersons]);
        end
        images(:, :, imageID, personID) = img;

%         figure(1)
%         imagesc(images(:, :, imageID, personID))
%         drawnow
    end
end

% Per-person split: images 1..nTraining form the training set, the next
% nTesting images of the same person form the test set.
nTraining = 10;
nTesting = 10;

% --- training set ---
% Slice the stack, flatten every image into one column (images of the same
% person end up in adjacent columns because the image index is the faster
% running dimension), then pass the matrix through substractMeanCols.
images_training = images(:, :, 1:nTraining, :);
images_training = reshape(images_training, size(images_training, 1)*size(images_training, 2), []);
images_training = substractMeanCols(images_training);

% --- testing set ---
% Same three steps applied to the held-out images.
images_testing = images(:, :, nTraining + (1:nTesting), :);
images_testing = reshape(images_testing, size(images_testing, 1)*size(images_testing, 2), []);
images_testing = substractMeanCols(images_testing);


% for personID = 1:10
%     for imageID = 1:10
%         personID, imageID
%     figure(1)
%     subplot(121)
%     imagesc(reshape(images_training(:,10*(personID-1)+1+imageID-1), 200,180))
%     subplot(122)
%     imagesc(reshape(images_testing(:,10*(personID-1)+1+imageID-1),  200,180))
%     waitforbuttonpress
%     end
% end

% Randomfaces features: each vectorised image is mapped to n dimensions by
% a linear projection whose matrix is drawn from a Gaussian distribution.
n = 756;    % number of rows of the projection matrix (feature dimension)
K = 100;    % size argument handed to initDictionaryFromPatches (presumably the number of atoms)

% One Gaussian entry per (feature, pixel) pair. normalizeColumns is applied
% to the transpose of R, i.e. it operates on the rows of R.
R = randn(n, numel(images(:, :, 1, 1)));
R = normalizeColumns(R.').';

% Projected training data, one column per training image.
Y0 = R * images_training;

% Initial dictionary derived from the projected training data.
D0 = initDictionaryFromPatches(n, K, Y0);

% X = OMP(D0,images_testing(:,2), T0)

%%

% Training hyper-parameters.
niter_learn = 20;
niter_coeff = 30;
niter_dict = 10;
T0 = 10;            % sparsity level passed to OMP / KSVD

% Number of classes (persons). The per-class sample and atom counts are
% derived from it so the label matrix and the class slices below cannot
% drift apart when the amount of training data changes.
nClasses = 10;
assert(mod(size(Y0, 2), nClasses) == 0 && mod(K, nClasses) == 0, ...
    'Training columns (%d) and dictionary size (%d) must be divisible by the number of classes (%d).', ...
    size(Y0, 2), K, nClasses);
samplesPerClass = size(Y0, 2) / nClasses;
atomsPerClass = K / nClasses;

% Class label matrix: H(c, j) = 1 when training column j belongs to class c,
% 0 otherwise. Relies on the columns of Y0 being grouped by class.
H = kron(eye(nClasses), ones(1, samplesPerClass));

% Settings for the (currently disabled) per-class KSVD call in the loop.
param.K = atomsPerClass;
param.numIteration = niter_learn;
param.InitializationMethod = 'DataElements';
param.preserveDCAtom = 0;
param.L = T0;

% Learn one sub-dictionary per class, then concatenate them column-wise.
D_parts = cell(1, nClasses);
for classIdx = 1:nClasses

    % Columns of Y0 that belong to the current class.
    classCols = (classIdx - 1)*samplesPerClass + (1:samplesPerClass);
    Y_part = Y0(:, classCols);
%     D_part = D0(:, (classIdx-1)*atomsPerClass + (1:atomsPerClass));

    D_part = initDictionaryFromPatches(n, atomsPerClass, Y_part);

%     [D_part, out] = KSVD(Y_part, param);
    % The second output (the codes) is not needed here; it is still
    % requested so learnDictionary sees the same number of outputs.
    [D_part, ~] = learnDictionary(Y_part, D_part, 5, 'nIterLearn', 5);

    D_parts{classIdx} = D_part;
end
D_cat = [D_parts{:}];


%%

% Sparse codes of the training features over the concatenated per-class
% dictionaries.
X = OMP(D_cat, Y0, T0);
% X = sparseCode(Y0, D_cat, T0, 10);

% Initialise the linear classifier W by ridge regression of the labels H on
% the codes X:  W0 = H*X' / (X*X' + I).  The Gram matrix is formed once.
G = X*X';
W0 = (H*X') / (G + eye(size(G)));

% Factor that controls reconstructive/discriminative dictionary properties.
gamma = 1000;

% Joint initial dictionary and joint training signals for K-SVD.
% D_cat (not D0) is stacked with W0: W0 was fitted on codes computed over
% D_cat, so column k of W0 belongs to atom k of D_cat. Pairing W0 with a
% different dictionary would mismatch atoms and classifier columns.
D = [D_cat; sqrt(gamma)*W0];
Y = [Y0; sqrt(gamma)*H];

param.K = K;
param.numIteration = 10;
param.InitializationMethod = 'GivenMatrix';
param.initialDictionary = D;
param.preserveDCAtom = 0;
param.L = T0;

[D, out] = KSVD(Y, param);

% Split the learned joint dictionary back into its dictionary part (first n
% rows) and classifier part (remaining rows). Both are divided by the l2
% norm of the dictionary part of each atom.
atomNorms = sqrt(sum(abs(D(1:n, :)).^2, 1));
atomNorms(atomNorms == 0) = 1;   % leave all-zero atoms as they are instead of producing NaN/Inf

D_final = D(1:n, :) ./ atomNorms;
W_final = D(n+1:end, :) ./ atomNorms;

%%
% Project the test images with the same random matrix used for training.
Y_test = R*images_testing;
% Y_test = R*images_training;   % uncomment to evaluate on the training set

% Number of nonzero coefficients allowed per test sample.
% NOTE(review): training codes were computed with T0 nonzeros; confirm that
% the larger value used here is intended.
T_test = 50;
X = OMP(D_final, Y_test, T_test);

% Class scores: one column per test sample, one row per class.
classCode = full(W_final*X);

% Predicted class = row holding the largest absolute score. The dimension
% is given explicitly so the reduction always runs over the class rows.
[~, predictedLabels] = max(abs(classCode), [], 1);

% Ground truth: the test columns are grouped by person, with the same
% number of images for every person.
nClasses = size(W_final, 1);
samplesPerClass = size(Y_test, 2) / nClasses;
trueLabels = kron(1:nClasses, ones(1, samplesPerClass));

isCorrect = (predictedLabels == trueLabels);
accuracy = mean(isCorrect);
fprintf('Recognition accuracy: %.1f %% (%d of %d test images)\n', ...
    100*accuracy, nnz(isCorrect), numel(isCorrect));

% Visualise the predicted label of every test image.
figure
imagesc(predictedLabels)
